Work-Stealing-based Persistent Kernel #64
base: main
Pull request overview
This PR introduces a work-stealing-based persistent GEMM kernel that dynamically allocates tile IDs across compute units instead of using fixed partitioning. The implementation uses per-XCD (chiplet) atomic counters to reduce contention compared to global atomic operations. The work-stealing kernel is exposed as an opt-in feature through a new work_stealing parameter in the matmul APIs.
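The per-XCD counter idea can be illustrated with a small CPU-side simulation. This is only a sketch under stated assumptions, not the kernel from this PR: the names `AtomicCounter` and `run_work_stealing` are hypothetical, Python threads stand in for compute units, a lock-protected integer stands in for a GPU atomic add, and the policy of draining the home XCD's tile range before stealing from the others is assumed for illustration.

```python
import threading


class AtomicCounter:
    """Lock-protected stand-in for a GPU atomic fetch-and-add (hypothetical helper)."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def fetch_add(self, n=1):
        # Return the old value and advance the counter atomically.
        with self._lock:
            old = self._value
            self._value += n
            return old


def run_work_stealing(total_tiles, num_xcds, workers_per_xcd):
    """Return, per worker, the list of tile IDs it claimed (illustrative only)."""
    # Assumption: the tile space is split into contiguous per-XCD ranges.
    base, rem = divmod(total_tiles, num_xcds)
    starts, sizes, offset = [], [], 0
    for x in range(num_xcds):
        size = base + (1 if x < rem else 0)
        starts.append(offset)
        sizes.append(size)
        offset += size

    # One counter per XCD, so workers mostly contend within their own XCD.
    counters = [AtomicCounter() for _ in range(num_xcds)]
    num_workers = num_xcds * workers_per_xcd
    claimed = [[] for _ in range(num_workers)]

    def worker(wid):
        home = wid // workers_per_xcd
        # Drain the home XCD first, then try the other XCDs in round-robin order.
        for step in range(num_xcds):
            xcd = (home + step) % num_xcds
            while True:
                local_id = counters[xcd].fetch_add(1)
                if local_id >= sizes[xcd]:
                    break  # this XCD's range is exhausted
                claimed[wid].append(starts[xcd] + local_id)

    threads = [threading.Thread(target=worker, args=(w,)) for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return claimed


if __name__ == "__main__":
    result = run_work_stealing(total_tiles=37, num_xcds=4, workers_per_xcd=3)
    all_tiles = sorted(tile for per_worker in result for tile in per_worker)
    # Every tile is claimed exactly once, whichever worker got it.
    assert all_tiles == list(range(37))
```

Which worker ends up with which tile depends on scheduling, but each tile ID is handed out exactly once because every claim goes through a fetch-and-add on a single counter.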
Changes:
- Added `MatmulConfig` class to pre-allocate and manage GPU buffers for kernel launches (tile counters, stream-K locks/partials)
- Implemented work-stealing kernel with per-XCD atomic tile counters in `persistent_gemm_work_stealing.py`
- Extended all matmul APIs with optional `work_stealing` and `config` parameters to support the new kernel
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 22 comments.
| File | Description |
|---|---|
| include/tritonblas/matmul.py | Added MatmulConfig class for buffer management; integrated work_stealing parameter and ws_persistent_matmul kernel; refactored buffer allocation to use config objects |
| include/tritonblas/kernels/persistent_gemm_work_stealing.py | New work-stealing kernel implementation with per-XCD atomic counters and dynamic tile assignment |
| include/tritonblas/kernels/__init__.py | Exported ws_persistent_matmul kernel |
| include/tritonblas/__init__.py | Exported MatmulConfig and matmul_preamble to public API |
| tests/test_work_stealing.py | Standalone test with custom module loading to test work-stealing kernel correctness and performance |
| benchmarks/benchmark_work_stealing.py | Comprehensive benchmark comparing work-stealing against static persistent, stream-K, and torch.matmul |
Motivation
Let each compute unit claim tile IDs dynamically at run time instead of receiving a fixed partition of tiles up front.
Getting Started